{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dYwg9btt1wJH"
   },
   "source": [
    "# BERT\n",
    "\n",
    "**Bidirectional Encoder Representations from Transformers.**\n",
    "\n",
    "_ | _\n",
    "- | -\n",
    "![alt](https://pytorch.org/assets/images/bert1.png) | ![alt](https://pytorch.org/assets/images/bert2.png)\n",
    "\n",
    "\n",
    "### **Overview**\n",
    "\n",
    "BERT was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin *et al.* The model is based on the Transformer architecture introduced in [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani *et al.* and has led to significant improvements in a wide range of natural language tasks.\n",
    "\n",
    "At the highest level, BERT maps from a block of text to a numeric vector which summarizes the relevant information in the text.\n",
    "\n",
    "What is remarkable is that numeric summary is sufficiently informative that, for example, the numeric summary of a paragraph followed by a reading comprehension question contains all the information necessary to satisfactorily answer the question.\n",
    "\n",
    "#### **Transfer Learning**\n",
    "\n",
    "BERT is a great example of a paradigm called *transfer learning*, which has proved very effective in recent years. In the first step, a network is trained on an unsupervised task using massive amounts of data. In the case of BERT, it was trained to predict missing words and to detect when pairs of sentences are presented in reversed order using all of Wikipedia. This was initially done by Google, using intense computational resources.\n",
    "\n",
    "Once this network has been trained, it is then used to perform many other supervised tasks using only limited data and computational resources: for example, sentiment classification in tweets or question answering. The network is re-trained to perform these other tasks in such a way that only the final, output parts of the network are allowed to adjust by very much, so that most of the \"information'' originally learned the network is preserved. This process is called *fine tuning*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TgWpXdSIl5KL"
   },
   "source": [
    "##Getting to know BERT\n",
    "\n",
    "BERT, and many of its variants, are made available to the public by the open source [Huggingface Transformers](https://huggingface.co/transformers/) project. This is an amazing resource, giving researchers and practitioners easy-to-use access to this technology.\n",
    "\n",
    "In order to use BERT for modeling, we simply need to download the pre-trained neural network and fine tune it on our dataset, which is illustrated below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y7hnE0rbdOmB"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Install Huggingface Transformers toolkit\n",
    "!pip install transformers==4.37.2    # Last version of transformers before keras 3\n",
    "!pip install shap\n",
    "!pip install tensorflow_addons\n",
    "!pip install livelossplot\n",
    "!pip install sqldf\n",
    "# !pip install auto-sklearn\n",
    "# !pip install -U scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1ku5QTraekzt"
   },
   "outputs": [],
   "source": [
    "gdrive = False  # set to True to save weights to your google drive\n",
    "if gdrive:\n",
    "    # Mount google drive so we can save stuff for later\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/gdrive')\n",
    "    weights_folder = \"/content/gdrive/MyDrive\"\n",
    "else:\n",
    "    weights_folder = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hGRPttHC5WJU"
   },
   "outputs": [],
   "source": [
    "# Load dependencies\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sqldf as sql\n",
    "import plotnine as p9\n",
    "p9.theme_set(p9.theme_bw)\n",
    "\n",
    "import statsmodels.formula.api as sm\n",
    "from transformers import TFBertModel, BertTokenizer\n",
    "\n",
    "# Formatting tools\n",
    "from pprint import pformat\n",
    "np.set_printoptions(threshold=10)\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "from tensorflow.keras import Model, Input\n",
    "from tensorflow.keras.layers import Dense, Concatenate\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow.keras import regularizers\n",
    "from livelossplot import PlotLossesKeras\n",
    "\n",
    "from sklearn.linear_model import LassoCV\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tLiPywGO0ZEd"
   },
   "outputs": [],
   "source": [
    "def ssq(x):\n",
    "    return np.inner(x, x)\n",
    "\n",
    "\n",
    "def get_r2(y, yhat):\n",
    "    resids = yhat.reshape(-1) - y\n",
    "    flucs = y - np.mean(y)\n",
    "    print('RSS: {}, TSS + MEAN^2: {}, TSS: {}, R^2: {}'.format(ssq(resids),\n",
    "                                                               ssq(y),\n",
    "                                                               ssq(flucs),\n",
    "                                                               1 - ssq(resids) / ssq(flucs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WPY0wQktrZb1"
   },
   "outputs": [],
   "source": [
    "# Load TensorFlow, and ensure GPU is present\n",
    "# The GPU will massively speed up neural network training\n",
    "device_name = tf.test.gpu_device_name()\n",
    "if device_name != '/device:GPU:0':\n",
    "    warnings.warn('GPU device not found')\n",
    "print('Found GPU at: {}'.format(device_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8aNZcJwIcL5T"
   },
   "outputs": [],
   "source": [
    "# Download text pre-processor (\"tokenizer\")\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K_ljmPI3cEQI"
   },
   "outputs": [],
   "source": [
    "# Download BERT model\n",
    "bert = TFBertModel.from_pretrained(\"bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "26mRwUFwardQ"
   },
   "source": [
    "### Tokenization\n",
    "\n",
    "The first step in using BERT (or any similar text embedding tool) is to *tokenize* the data. This step standardizes blocks of text, so that meaningless differences in text presentation don't affect the behavior of our algorithm.\n",
    "\n",
    "Typically the text is transformed into a sequence of 'tokens,' each of which corresponds to a numeric code."
   ]
  },
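  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of subword tokens concrete, here is a minimal sketch of greedy longest-match splitting in the style of WordPiece, the scheme that BERT's tokenizer is based on. The vocabulary `TOY_VOCAB` and the helpers `toy_wordpiece` and `toy_tokenize` are made up for illustration. The real tokenizer uses a vocabulary of roughly 30,000 entries and more careful text normalization, so its output for the same string may differ.\n",
    "\n",
    "```python\n",
    "# Toy greedy longest-match-first subword tokenizer (WordPiece-style sketch).\n",
    "# TOY_VOCAB is a made-up vocabulary for illustration, not BERT's real one.\n",
    "TOY_VOCAB = {'[UNK]': 0, '[CLS]': 1, '[SEP]': 2, 'what': 3, 'happens': 4,\n",
    "             'to': 5, 'this': 6, 'str': 7, '##ing': 8, '?': 9}\n",
    "\n",
    "\n",
    "def toy_wordpiece(word, vocab):\n",
    "    '''Split one lowercase word into the longest matching vocabulary pieces.'''\n",
    "    pieces, start = [], 0\n",
    "    while start < len(word):\n",
    "        end = len(word)\n",
    "        match = None\n",
    "        while end > start:\n",
    "            piece = word[start:end]\n",
    "            if start > 0:\n",
    "                piece = '##' + piece  # continuation pieces carry a '##' prefix\n",
    "            if piece in vocab:\n",
    "                match = piece\n",
    "                break\n",
    "            end -= 1\n",
    "        if match is None:\n",
    "            return ['[UNK]']  # no piece fits: the whole word is unknown\n",
    "        pieces.append(match)\n",
    "        start = end\n",
    "    return pieces\n",
    "\n",
    "\n",
    "def toy_tokenize(text, vocab):\n",
    "    '''Lowercase, split off question marks, and wrap with [CLS] / [SEP].'''\n",
    "    words = text.lower().replace('?', ' ?').split()\n",
    "    tokens = ['[CLS]']\n",
    "    for word in words:\n",
    "        tokens.extend(toy_wordpiece(word, vocab))\n",
    "    tokens.append('[SEP]')\n",
    "    return tokens, [vocab[t] for t in tokens]\n",
    "\n",
    "\n",
    "tokens, ids = toy_tokenize('What happens to this string?', TOY_VOCAB)\n",
    "print(tokens)\n",
    "print(ids)\n",
    "```\n",
    "\n",
    "Because `string` is deliberately left out of the toy vocabulary, it is split into `str` and `##ing`. This is how a subword tokenizer can represent words it has never seen as whole units."
   ]
  },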
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vnLlb5peTwqM"
   },
   "outputs": [],
   "source": [
    "# Let's try it out!\n",
    "s = \"What happens to this string?\"\n",
    "print('Original String: \\n\\\"{}\\\"\\n'.format(s))\n",
    "tensors = tokenizer(s)\n",
    "print('Numeric encoding: \\n' + pformat(tensors))\n",
    "\n",
    "# What does this mean?\n",
    "print('\\nActual tokens:')\n",
    "tokenizer.convert_ids_to_tokens(tensors['input_ids'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JJaz6eEocefa"
   },
   "source": [
    "### BERT in a nutshell\n",
    "\n",
    "Once we have our numeric tokens, we can simply plug them into the BERT network and get a numeric vector summary. Note that in applications, the BERT summary will be \"fine tuned\" to a particular task, which hasn't happened yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q1ODAgBMa3Zg"
   },
   "outputs": [],
   "source": [
    "print('Input: \"What happens to this string?\"\\n')\n",
    "\n",
    "# Tokenize the string\n",
    "tensors_tf = tokenizer(\"What happens to this string?\", return_tensors=\"tf\")\n",
    "\n",
    "# Run it through BERT\n",
    "output = bert(tensors_tf)\n",
    "\n",
    "# Inspect the output\n",
    "_shape = output['pooler_output'].shape\n",
    "print(\"\"\"Output type: {}\\n\n",
    "      Output shape: {}\\n\n",
    "      Output preview: {}\\n\"\"\".format(type(output['pooler_output']),\n",
    "                                     _shape,\n",
    "                                     pformat(output['pooler_output'].numpy())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y_CnEClsl_1p"
   },
   "source": [
    "# A practical introduction to BERT\n",
    "\n",
    "In the next part of the notebook, we are going to explore how a tool like BERT may be useful for causal inference.\n",
    "\n",
    "In particular, we are going to apply BERT to a subset of data from the Amazon marketplace consisting of roughly 10,000 listings for products in the toy category. Each product comes with a text description, a price, and a number of times reviewed (which we'll use as a proxy for demand / market share).\n",
    "\n",
    "For more information on the dataset, checkout the [Dataset README](https://github.com/CausalAIBook/MetricsMLNotebooks/blob/main/data/amazon_toys.md).\n",
    "\n",
    "**For thought**:\n",
    "What are some issues you may anticipate when using number of reviews as a proxy for demand or market share?\n",
    "\n",
    "### Getting to know the data\n",
    "\n",
    "First, we'll download and clean up the data, and do some preliminary inspection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5kzXygwH0BKw"
   },
   "outputs": [],
   "source": [
    "DATA_URL = 'https://github.com/CausalAIBook/MetricsMLNotebooks/raw/main/data/amazon_toys.csv'\n",
    "data = pd.read_csv(DATA_URL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NDgWwxUWErnh"
   },
   "outputs": [],
   "source": [
    "data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1Su5vOGhD3Df"
   },
   "outputs": [],
   "source": [
    "# Clean numeric data fields (remove all non-digit characters and parse as a numeric value)\n",
    "data['number_of_reviews'] = pd.to_numeric(data\n",
    "                                          .number_of_reviews\n",
    "                                          .str.replace(',', '')  # Remove commas\n",
    "                                          .str.replace(r\"\\D+\", ''))\n",
    "data['price'] = (data\n",
    "                 .price\n",
    "                 .str.extract(r'(\\d+\\.*\\d+)')\n",
    "                 .astype('float'))\n",
    "\n",
    "# Drop products with very few reviews\n",
    "data = data[data['number_of_reviews'] > 0]\n",
    "\n",
    "# Compute log prices\n",
    "data['ln_p'] = np.log(data.price)\n",
    "\n",
    "# Impute market shares from # of reviews\n",
    "data['ln_q'] = np.log(data['number_of_reviews'] / data['number_of_reviews'].sum())\n",
    "\n",
    "# Collect relevant text data\n",
    "data['text'] = (data[['product_name',\n",
    "                      'manufacturer',\n",
    "                      'product_description']]\n",
    "                .astype('str')\n",
    "                .agg(' | '.join, axis=1))\n",
    "\n",
    "#  Drop irrelevant data and inspect\n",
    "data = data[['text', 'ln_p', 'ln_q', 'amazon_category_and_sub_category']]\n",
    "data = data.dropna()\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lovFEHaWp4lC"
   },
   "outputs": [],
   "source": [
    "# Text lengths\n",
    "data['text_num_words'] = data['text'].str.split().apply(len)\n",
    "print(np.nanquantile(data['text_num_words'], 0.99))\n",
    "(p9.ggplot(data, p9.aes('text_num_words')) + p9.geom_density())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CDlHPQZfcv7I"
   },
   "source": [
    "Let's make a two-way scatter plot of prices and (proxied) market shares."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dNujhir1q_0N"
   },
   "outputs": [],
   "source": [
    "(p9.ggplot(data, p9.aes('ln_p', 'ln_q')) + p9.geom_point() + p9.stat_smooth(color=\"red\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "t7axTjsnq8qI"
   },
   "outputs": [],
   "source": [
    "(p9.ggplot(data, p9.aes('ln_p', 'ln_q')) + p9.stat_smooth(color=\"red\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aZPOUT5SvR4o"
   },
   "outputs": [],
   "source": [
    "result = sm.ols('ln_q ~ ln_p ', data=data).fit()\n",
    "print('Elasticity: {}, SE: {}, R2: {}'.format(result.params['ln_p'], result.bse['ln_p'], result.rsquared_adj))\n",
    "result.conf_int(alpha=0.05)"
   ]
  },
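  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The regression above reads the coefficient on `ln_p` as a price elasticity. The sketch below shows why on synthetic, noise-free data: if demand has constant elasticity, the relationship between log quantity and log price is linear and the slope is the elasticity. The constants `A` and `true_elasticity` and the price grid are made-up values, and `np.polyfit` stands in for the `statsmodels` regression used in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic, noise-free constant-elasticity demand: q = A * p ** elasticity.\n",
    "# A, true_elasticity and the price grid are made-up values for illustration.\n",
    "A, true_elasticity = 50.0, -1.5\n",
    "p = np.linspace(1.0, 20.0, 200)\n",
    "q = A * p ** true_elasticity\n",
    "\n",
    "# In logs the relationship is linear: ln q = ln A + elasticity * ln p,\n",
    "# so the least-squares slope of ln q on ln p recovers the elasticity.\n",
    "slope, intercept = np.polyfit(np.log(p), np.log(q), deg=1)\n",
    "print(slope, np.exp(intercept))\n",
    "```\n",
    "\n",
    "With real data the slope is only a causal elasticity if nothing else drives both price and demand, which is the concern the rest of this notebook tries to address using the product text."
   ]
  },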
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cr2Ha215l4QK"
   },
   "source": [
    "Let's begin with a simple prediction task. We will discover how well can we explain the price of these products using their textual descriptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pGt-G-qpciGd"
   },
   "outputs": [],
   "source": [
    "main_ind, test_ind = train_test_split(np.arange(data.shape[0]), test_size=.2, shuffle=True, random_state=124)\n",
    "main = data.iloc[main_ind]\n",
    "\n",
    "train_ind, val_ind = train_test_split(np.arange(main.shape[0]), test_size=0.25, random_state=124)  # 0.25 x 0.8 = 0.2\n",
    "\n",
    "train = main.iloc[train_ind]\n",
    "val = main.iloc[val_ind]\n",
    "holdout = data.iloc[test_ind]\n",
    "\n",
    "tensors = tokenizer(\n",
    "    list(train[\"text\"]),\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    "    max_length=128,\n",
    "    return_tensors=\"tf\")\n",
    "\n",
    "val_tensors = tokenizer(\n",
    "    list(val[\"text\"]),\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    "    max_length=128,\n",
    "    return_tensors=\"tf\")\n",
    "\n",
    "# Preprocess holdout sample\n",
    "tensors_holdout = tokenizer(\n",
    "    list(holdout[\"text\"]),\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    "    max_length=128,\n",
    "    return_tensors=\"tf\")"
   ]
  },
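  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tokenizer calls above turn texts of different lengths into fixed-width arrays. The following sketch illustrates what padding, truncation and the attention mask mean for a single sequence of token ids. The helper `pad_or_truncate`, the constant `PAD_ID` and the example ids are hypothetical and only for illustration. Unlike the real tokenizer, this simplified version does not keep the closing `[SEP]` token when it truncates.\n",
    "\n",
    "```python\n",
    "# Sketch of padding / truncation to a fixed length, with an attention mask.\n",
    "# PAD_ID = 0 and the example ids are illustrative assumptions.\n",
    "PAD_ID = 0\n",
    "\n",
    "\n",
    "def pad_or_truncate(ids, max_length, pad_id=PAD_ID):\n",
    "    '''Return (input_ids, attention_mask), both of length max_length.'''\n",
    "    kept = ids[:max_length]               # truncation: drop ids past max_length\n",
    "    n_pad = max_length - len(kept)\n",
    "    input_ids = kept + [pad_id] * n_pad   # padding: fill the remainder\n",
    "    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding\n",
    "    return input_ids, attention_mask\n",
    "\n",
    "\n",
    "short_ids, short_mask = pad_or_truncate([101, 2054, 102], max_length=5)\n",
    "long_ids, long_mask = pad_or_truncate([101, 7, 8, 9, 10, 11, 102], max_length=5)\n",
    "print(short_ids, short_mask)\n",
    "print(long_ids, long_mask)\n",
    "```\n",
    "\n",
    "The attention mask tells the network which positions hold real tokens, so that padding added only to equalize lengths does not influence the embedding."
   ]
  },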
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XQ4DMJQ0drZm"
   },
   "outputs": [],
   "source": [
    "ln_p = train[\"ln_p\"]\n",
    "ln_q = train[\"ln_q\"]\n",
    "val_ln_p = val[\"ln_p\"]\n",
    "val_ln_q = val[\"ln_q\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gPYEMuKZ7ylj"
   },
   "source": [
    "# Using BERT as Feature Extractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AJu0FlEcu6Zj"
   },
   "outputs": [],
   "source": [
    "input_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "token_type_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "attention_mask = Input(shape=(128,), dtype=tf.int32)\n",
    "\n",
    "Z = bert(input_ids, token_type_ids, attention_mask)[1]\n",
    "\n",
    "embedding_model = Model([input_ids, token_type_ids, attention_mask], Z)\n",
    "\n",
    "embeddings = embedding_model.predict([tensors['input_ids'], tensors['token_type_ids'], tensors['attention_mask']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-SSqzkf3vx8I"
   },
   "outputs": [],
   "source": [
    "lcv = make_pipeline(StandardScaler(), LassoCV(cv=KFold(n_splits=5, shuffle=True, random_state=123), random_state=123))\n",
    "lcv.fit(embeddings, ln_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hLEZ1JL2weU9"
   },
   "outputs": [],
   "source": [
    "embeddings_val = embedding_model.predict([val_tensors['input_ids'],\n",
    "                                          val_tensors['token_type_ids'],\n",
    "                                          val_tensors['attention_mask']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4Aez3Il1wp4x"
   },
   "outputs": [],
   "source": [
    "get_r2(val_ln_p, lcv.predict(embeddings_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u1aRxiWsBO-V"
   },
   "outputs": [],
   "source": [
    "embeddings_holdout = embedding_model.predict([tensors_holdout['input_ids'],\n",
    "                                              tensors_holdout['token_type_ids'],\n",
    "                                              tensors_holdout['attention_mask']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "899QoDe0BVId"
   },
   "outputs": [],
   "source": [
    "get_r2(holdout['ln_p'], lcv.predict(embeddings_holdout))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RwzWK34wBY9I"
   },
   "outputs": [],
   "source": [
    "ln_p_hat_holdout = lcv.predict(embeddings_holdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mOc1_C5p7ta7"
   },
   "source": [
    "# Linear Probing: Training Only Final Layer after BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ck1xqRIrmx8I"
   },
   "outputs": [],
   "source": [
    "# Now let's prepare our model\n",
    "tf.keras.utils.set_random_seed(123)\n",
    "\n",
    "input_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "token_type_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "attention_mask = Input(shape=(128,), dtype=tf.int32)\n",
    "\n",
    "# First we compute the text embedding\n",
    "Z = bert(input_ids, token_type_ids, attention_mask)\n",
    "\n",
    "for layer in bert.layers:\n",
    "    layer.trainable = False\n",
    "    for w in layer.weights:\n",
    "        w._trainable = False\n",
    "\n",
    "# We want the \"pooled / summary\" embedding, not individual word embeddings\n",
    "Z = Z[1]\n",
    "\n",
    "# Then we do a regular regression\n",
    "ln_p_hat = Dense(1, activation='linear',\n",
    "                 kernel_regularizer=regularizers.L2(1e-3))(Z)\n",
    "\n",
    "PricePredictionNetwork = Model([input_ids,\n",
    "                                token_type_ids,\n",
    "                                attention_mask], ln_p_hat)\n",
    "PricePredictionNetwork.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),\n",
    "    loss=tf.keras.losses.MeanSquaredError(),\n",
    "    metrics=tfa.metrics.RSquare(),\n",
    ")\n",
    "PricePredictionNetwork.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XhTREb3NcZhH"
   },
   "outputs": [],
   "source": [
    "tf.keras.utils.set_random_seed(123)\n",
    "earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(weights_folder, \"pweights.hdf5\"),\n",
    "                                                     monitor='val_loss',\n",
    "                                                     save_best_only=True,\n",
    "                                                     save_weights_only=True)\n",
    "\n",
    "PricePredictionNetwork.fit(x=[tensors['input_ids'],\n",
    "                              tensors['token_type_ids'],\n",
    "                              tensors['attention_mask']],\n",
    "                           y=ln_p,\n",
    "                           validation_data=([val_tensors['input_ids'],\n",
    "                                             val_tensors['token_type_ids'],\n",
    "                                             val_tensors['attention_mask']], val_ln_p),\n",
    "                           epochs=5,\n",
    "                           callbacks=[earlystopping, modelcheckpoint,\n",
    "                                      PlotLossesKeras(groups={'train_loss': ['loss'],\n",
    "                                                              'train_rsq':['r_square'],\n",
    "                                                              'val_loss': ['val_loss'],\n",
    "                                                              'val_rsq': ['val_r_square']})],\n",
    "                           batch_size=16,\n",
    "                           shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MyFxR5GC8C3K"
   },
   "source": [
    "# Fine Tuning starting from the Linear Probing Trained Weights\n",
    "\n",
    "Now we train the whole network, initializing the weights based on the result of the linear probing phase in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NzWCkTY87luH"
   },
   "outputs": [],
   "source": [
    "# Now let's prepare our model\n",
    "tf.keras.utils.set_random_seed(123)\n",
    "\n",
    "input_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "token_type_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "attention_mask = Input(shape=(128,), dtype=tf.int32)\n",
    "\n",
    "# First we compute the text embedding\n",
    "Z = bert(input_ids, token_type_ids, attention_mask)\n",
    "\n",
    "for layer in bert.layers:\n",
    "    layer.trainable = True\n",
    "    for w in layer.weights:\n",
    "        w._trainable = True\n",
    "\n",
    "# We want the \"pooled / summary\" embedding, not individual word embeddings\n",
    "Z = Z[1]\n",
    "\n",
    "# Then we do a regularized linear regression\n",
    "ln_p_hat = Dense(1, activation='linear',\n",
    "                 kernel_regularizer=regularizers.L2(1e-3))(Z)\n",
    "\n",
    "PricePredictionNetwork = Model([input_ids,\n",
    "                                token_type_ids,\n",
    "                                attention_mask], ln_p_hat)\n",
    "PricePredictionNetwork.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
    "    loss=tf.keras.losses.MeanSquaredError(),\n",
    "    metrics=tfa.metrics.RSquare(),\n",
    ")\n",
    "PricePredictionNetwork.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PWauCl0T7nUo"
   },
   "outputs": [],
   "source": [
    "PricePredictionNetwork.load_weights(os.path.join(weights_folder, \"pweights.hdf5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iDSBZWAe8nhE"
   },
   "outputs": [],
   "source": [
    "tf.keras.utils.set_random_seed(123)\n",
    "\n",
    "earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(weights_folder, \"pweights.hdf5\"),\n",
    "                                                     monitor='val_loss',\n",
    "                                                     save_best_only=True,\n",
    "                                                     save_weights_only=True)\n",
    "\n",
    "PricePredictionNetwork.fit(x=[tensors['input_ids'],\n",
    "                              tensors['token_type_ids'],\n",
    "                              tensors['attention_mask']],\n",
    "                           y=ln_p,\n",
    "                           validation_data=([val_tensors['input_ids'],\n",
    "                                             val_tensors['token_type_ids'],\n",
    "                                             val_tensors['attention_mask']], val_ln_p),\n",
    "                           epochs=10,\n",
    "                           callbacks=[earlystopping, modelcheckpoint,\n",
    "                                      PlotLossesKeras(groups={'train_loss': ['loss'],\n",
    "                                                              'train_rsq':['r_square'],\n",
    "                                                              'val_loss': ['val_loss'],\n",
    "                                                              'val_rsq': ['val_r_square']})],\n",
    "                           batch_size=16,\n",
    "                           shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wchpbXoqBAJu"
   },
   "outputs": [],
   "source": [
    "PricePredictionNetwork.load_weights(os.path.join(weights_folder, \"pweights.hdf5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jpUmDHYfkJEZ"
   },
   "outputs": [],
   "source": [
    "# Compute predictions\n",
    "ln_p_hat_holdout = PricePredictionNetwork.predict([tensors_holdout['input_ids'],\n",
    "                                                   tensors_holdout['token_type_ids'],\n",
    "                                                   tensors_holdout['attention_mask']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "g_XK81hpkQMN"
   },
   "outputs": [],
   "source": [
    "print('Neural Net R^2, Price Prediction:')\n",
    "get_r2(holdout['ln_p'], ln_p_hat_holdout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GR4QP4DJPQk0"
   },
   "outputs": [],
   "source": [
    "plt.hist(ln_p_hat_holdout)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bafy9ftcoBed"
   },
   "source": [
    "Now, let's go one step further and construct a DML estimator of the price elasticity in a partially linear model. In particular, we will model market share $q_i$ as\n",
    "$$\\ln q_i = \\alpha + \\beta \\ln p_i + \\psi(d_i) + \\epsilon_i,$$ where $d_i$ denotes the description of product $i$ and $\\psi$ is the composition of text embedding and a linear layer."
   ]
  },
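  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before fitting the networks, it may help to see the residual-on-residual logic of the partially linear model on synthetic data. In this sketch a scalar `x` stands in for the product description and a polynomial regression stands in for the BERT-based prediction networks. All numbers, the variable `x` and the helper `fit_predict` are illustrative assumptions, not part of the analysis in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Synthetic data; every number below is an illustrative assumption.\n",
    "n, true_beta = 4000, -1.0\n",
    "x = rng.uniform(-2.0, 2.0, size=n)             # stand-in for the text features\n",
    "ln_p = 0.5 * x**2 + 0.5 * rng.normal(size=n)   # price depends on x\n",
    "ln_q = true_beta * ln_p + x**2 + 0.5 * rng.normal(size=n)  # so does demand\n",
    "\n",
    "# Split: fit nuisance functions on one half, estimate on the held-out half.\n",
    "train, hold = np.arange(n) < n // 2, np.arange(n) >= n // 2\n",
    "\n",
    "\n",
    "def fit_predict(y):\n",
    "    '''Degree-4 polynomial regression of y on x, fitted on train only.'''\n",
    "    coefs = np.polyfit(x[train], y[train], deg=4)\n",
    "    return np.polyval(coefs, x[hold])\n",
    "\n",
    "\n",
    "# Residualize both variables on x, then regress residual on residual.\n",
    "r_p = ln_p[hold] - fit_predict(ln_p)\n",
    "r_q = ln_q[hold] - fit_predict(ln_q)\n",
    "beta_hat = np.mean(r_p * r_q) / np.mean(r_p**2)\n",
    "\n",
    "# Heteroskedasticity-robust standard error based on the final-stage residual.\n",
    "eps = r_q - beta_hat * r_p\n",
    "se = np.sqrt(np.mean((r_p * eps)**2) / np.mean(r_p**2)**2 / r_p.size)\n",
    "\n",
    "# Naive regression that ignores x, for comparison.\n",
    "naive = np.polyfit(ln_p[hold], ln_q[hold], deg=1)[0]\n",
    "print(beta_hat, se, naive)\n",
    "```\n",
    "\n",
    "In this simulation `x` raises both price and demand, so the naive slope is pulled well away from `true_beta`, while the residualized estimate should land close to it. The cells that follow apply the same steps, with text-based predictions of `ln_p` and `ln_q` playing the role of `fit_predict`."
   ]
  },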
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qiteu6FaoctV"
   },
   "outputs": [],
   "source": [
    "# Build the quantity prediction network\n",
    "tf.keras.utils.set_random_seed(123)\n",
    "\n",
    "# Initialize new BERT model from original\n",
    "bert2 = TFBertModel.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# for layer in bert2.layers:\n",
    "#     layer.trainable=False\n",
    "#     for w in layer.weights: w._trainable=False\n",
    "\n",
    "# Define inputs\n",
    "input_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "token_type_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "attention_mask = Input(shape=(128,), dtype=tf.int32)\n",
    "\n",
    "# First we compute the text embedding\n",
    "Z = bert2(input_ids, token_type_ids, attention_mask)\n",
    "\n",
    "# We want the \"pooled / summary\" embedding, not individual word embeddings\n",
    "Z = Z[1]\n",
    "\n",
    "ln_q_hat = Dense(1, activation='linear', kernel_regularizer=regularizers.L2(1e-3))(Z)\n",
    "\n",
    "# Compile model and optimization routine\n",
    "QuantityPredictionNetwork = Model([input_ids,\n",
    "                                   token_type_ids,\n",
    "                                   attention_mask], ln_q_hat)\n",
    "QuantityPredictionNetwork.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
    "                                  loss=tf.keras.losses.MeanSquaredError(),\n",
    "                                  metrics=tfa.metrics.RSquare())\n",
    "QuantityPredictionNetwork.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aaxHV0gGMqpw"
   },
   "outputs": [],
   "source": [
    "# Fit the quantity prediction network in the main sample\n",
    "tf.keras.utils.set_random_seed(123)\n",
    "\n",
    "earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(weights_folder, \"qweights.hdf5\"),\n",
    "                                                     monitor='val_loss',\n",
    "                                                     save_best_only=True,\n",
    "                                                     save_weights_only=True)\n",
    "\n",
    "QuantityPredictionNetwork.fit([tensors['input_ids'],\n",
    "                               tensors['token_type_ids'],\n",
    "                               tensors['attention_mask']],\n",
    "                              ln_q,\n",
    "                              validation_data=([val_tensors['input_ids'],\n",
    "                                                val_tensors['token_type_ids'],\n",
    "                                                val_tensors['attention_mask']], val_ln_q),\n",
    "                              epochs=10,\n",
    "                              callbacks=[earlystopping, modelcheckpoint,\n",
    "                                         PlotLossesKeras(groups={'train_loss': ['loss'],\n",
    "                                                                 'train_rsq':['r_square'],\n",
    "                                                                 'val_loss': ['val_loss'],\n",
    "                                                                 'val_rsq': ['val_r_square']})],\n",
    "                              batch_size=16,\n",
    "                              shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TfyQV3lw-xf2"
   },
   "outputs": [],
   "source": [
    "QuantityPredictionNetwork.load_weights(os.path.join(weights_folder, \"qweights.hdf5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YADpNj0jMygZ"
   },
   "outputs": [],
   "source": [
    "# Predict in the holdout sample, residualize and regress\n",
    "ln_q_hat_holdout = QuantityPredictionNetwork.predict([tensors_holdout['input_ids'],\n",
    "                                                      tensors_holdout['token_type_ids'],\n",
    "                                                      tensors_holdout['attention_mask']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jh4criU1hGIP"
   },
   "outputs": [],
   "source": [
    "print('Neural Net R^2, Quantity Prediction:')\n",
    "get_r2(holdout['ln_q'], ln_q_hat_holdout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ir-_yAfkPM6f"
   },
   "outputs": [],
   "source": [
    "# Compute residuals\n",
    "r_p = holdout[\"ln_p\"] - ln_p_hat_holdout.reshape((-1,))\n",
    "r_q = holdout[\"ln_q\"] - ln_q_hat_holdout.reshape((-1,))\n",
    "\n",
    "# Regress to obtain elasticity estimate\n",
    "beta = np.mean(r_p * r_q) / np.mean(r_p * r_p)\n",
    "\n",
    "# standard error on elastiticy estimate\n",
    "se = np.sqrt(np.mean((r_p * r_q)**2) / (np.mean(r_p * r_p)**2) / holdout[\"ln_p\"].size)\n",
    "\n",
    "print('Elasticity of Demand with Respect to Price: {}'.format(beta))\n",
    "print('Standard Error: {}'.format(se))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VCqeRTB_BNEH"
   },
   "source": [
    "# Heterogeneous Elasticities within Major Product Categories\n",
    "\n",
    "We now look at the major product categories that have many products, and we investigate whether the \"within group\" price elasticities vary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XRUZEXqc8HPG"
   },
   "outputs": [],
   "source": [
    "holdout['category'] = holdout['amazon_category_and_sub_category'].str.split('>').apply(lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ymWJv4Ej7lt9"
   },
   "outputs": [],
   "source": [
    "# Elasticity within the main product categories\n",
    "sql.run(\"\"\"\n",
    "  SELECT category, COUNT(*)\n",
    "  FROM holdout\n",
    "  GROUP BY 1\n",
    "  HAVING COUNT(*)>=100\n",
    "  ORDER BY 2 desc\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E3ncPtwt8nJi"
   },
   "outputs": [],
   "source": [
    "main_cats = sql.run(\"\"\"\n",
    "  SELECT category\n",
    "  FROM holdout\n",
    "  GROUP BY 1\n",
    "  HAVING COUNT(*)>=100\n",
    "\"\"\")['category']\n",
    "\n",
    "dfs = []\n",
    "for cat in main_cats:\n",
    "    r_p = holdout[holdout['category'] == cat][\"ln_p\"] - ln_p_hat_holdout.reshape((-1,))[holdout['category'] == cat]\n",
    "    r_q = holdout[holdout['category'] == cat][\"ln_q\"] - ln_q_hat_holdout.reshape((-1,))[holdout['category'] == cat]\n",
    "    # Regress to obtain elasticity estimate\n",
    "    beta = np.mean(r_p * r_q) / np.mean(r_p * r_p)\n",
    "\n",
    "    # standard error on elastiticy estimate\n",
    "    se = np.sqrt(np.mean((r_p * r_q)**2) / (np.mean(r_p * r_p)**2) / holdout[\"ln_p\"].size)\n",
    "\n",
    "    df = pd.DataFrame({'point': beta, 'se': se, 'lower': beta - 1.96 * se, 'upper': beta + 1.96 * se}, index=[0])\n",
    "    df['category'] = cat\n",
    "    df['N'] = holdout[holdout['category'] == cat].shape[0]\n",
    "    dfs.append(df)\n",
    "\n",
    "df = pd.concat(dfs)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QFTjf9vP6Zfu"
   },
   "source": [
    "## Clustering Products\n",
    "\n",
    "In this final part of the notebook, we'll illustrate how the BERT text embeddings can be used to cluster products based on their  descriptions.\n",
    "\n",
    "Intiuitively, our neural network has now learned which aspects of the text description are relevant to predict prices and market shares.\n",
    "We can therefore use the embeddings produced by our network to cluster products, and we might expect that the clusters reflect market-relevant information.\n",
    "\n",
    "In the following block of cells, we compute embeddings using our learned models and cluster them using $k$-means clustering with $k=10$. Finally, we will explore how the estimated price elasticity differs across clusters.\n",
    "\n",
    "### Overview of **$k$-means clustering**\n",
    "The $k$-means clustering algorithm seeks to divide $n$ data vectors into $k$ groups, each of which contain points that are \"close together.\"\n",
    "\n",
    "In particular, let $C_1, \\ldots, C_k$ be a partitioning of the data into $k$ disjoint, nonempty subsets (clusters), and define\n",
    "$$\\bar{C_i}=\\frac{1}{\\#C_i}\\sum_{x \\in C_i} x$$\n",
    "to be the *centroid* of the cluster $C_i$. The $k$-means clustering score $\\mathrm{sc}(C_1 \\ldots C_k)$ is defined to be\n",
    "$$\\mathrm{sc}(C_1 \\ldots C_k) = \\sum_{i=1}^k \\sum_{x \\in C_i} \\left(x - \\bar{C_i}\\right)^2.$$\n",
    "\n",
    "The $k$-means clustering is then defined to be any partitioning $C^*_1 \\ldots C^*_k$ that minimizes the score $\\mathrm{sc}(-)$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mc7I00JPK6wJ"
   },
   "outputs": [],
   "source": [
    "# STEP 1: Compute embeddings\n",
    "\n",
    "input_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "token_type_ids = Input(shape=(128,), dtype=tf.int32)\n",
    "attention_mask = Input(shape=(128,), dtype=tf.int32)\n",
    "\n",
    "Y1 = bert(input_ids, token_type_ids, attention_mask)[1]\n",
    "Y2 = bert2(input_ids, token_type_ids, attention_mask)[1]\n",
    "Y = Concatenate()([Y1, Y2])\n",
    "\n",
    "embedding_model = Model([input_ids, token_type_ids, attention_mask], Y)\n",
    "\n",
    "embeddings = embedding_model.predict([tensors_holdout['input_ids'],\n",
    "                                      tensors_holdout['token_type_ids'],\n",
    "                                      tensors_holdout['attention_mask']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hCG2MunU6iF-"
   },
   "source": [
    "### Dimension reduction and the **Johnson-Lindenstrauss transform**\n",
    "\n",
    "Our learned embeddings have dimension in the $1000$s, and $k$-means clustering is often an expensive operation. To improve the situation, we will use a neat trick that is used extensively in machine learning applications: the *Johnson-Lindenstrauss transform*.\n",
    "\n",
    "This trick involves finding a low-dimensional linear projection of the embeddings that approximately preserves pairwise distances.\n",
    "\n",
    "In fact, Johnson and Lindenstrauss proved a much more interesting statement: a Gaussian random matrix will *almost always* approximately preserve pairwise distances.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "afGiLR7v6ecJ"
   },
   "outputs": [],
   "source": [
    "# STEP 2 Make low-dimensional projections\n",
    "from sklearn.random_projection import GaussianRandomProjection\n",
    "\n",
    "jl = GaussianRandomProjection(eps=.25)\n",
    "embeddings_lowdim = jl.fit_transform(embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9Tl9AM3J6j3X"
   },
   "outputs": [],
   "source": [
    "# STEP 3 Compute clusters\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "k_means = KMeans(n_clusters=10)\n",
    "k_means.fit(embeddings_lowdim)\n",
    "cluster_ids = k_means.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0l7De-Do6mD0"
   },
   "outputs": [],
   "source": [
    "# STEP 4 Regress within each cluster\n",
    "\n",
    "betas = np.zeros(10)\n",
    "ses = np.zeros(10)\n",
    "\n",
    "r_p = holdout[\"ln_p\"] - ln_p_hat_holdout.reshape((-1,))\n",
    "r_q = holdout[\"ln_q\"] - ln_q_hat_holdout.reshape((-1,))\n",
    "\n",
    "for c in range(10):\n",
    "\n",
    "    r_p_c = r_p[cluster_ids == c]\n",
    "    r_q_c = r_q[cluster_ids == c]\n",
    "\n",
    "    # Regress to obtain elasticity estimate\n",
    "    betas[c] = np.mean(r_p_c * r_q_c) / np.mean(r_p_c * r_p_c)\n",
    "\n",
    "    # standard error on elastiticy estimate\n",
    "    ses[c] = np.sqrt(np.mean((r_p_c * r_q_c)**2) / (np.mean(r_p_c * r_p_c)**2) / r_p_c.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oXoe98f06njT"
   },
   "outputs": [],
   "source": [
    "# STEP 5 Plot\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.bar(range(10), betas, yerr=1.96 * ses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8tTqUNeH6sxi"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuClass": "premium",
   "machine_shape": "hm",
   "provenance": [],
   "toc_visible": true
  },
  "gpuClass": "premium",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
